{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "238520e0",
   "metadata": {},
   "source": [
    "# Multimodal Search Using CLIP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3590f0e",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook demonstrates how SuperDuperDB can perform multimodal searches using the `VectorIndex`. It highlights SuperDuperDB's flexibility in integrating different models for vectorizing diverse queries during search and inference. In this example, we use the [CLIP multimodal architecture](https://openai.com/research/clip).\n",
    "\n",
    "In practice, this lets you vectorize a collection of items such as images and search it efficiently with either text or image queries."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40272d6a2681c8e8",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Prerequisites\n",
    "\n",
    "Before starting, make sure you have the required libraries installed. Run the following commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ebe1497",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install superduperdb\n",
    "!pip install ipython openai-clip\n",
    "!pip install -U datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2f94ae8",
   "metadata": {},
   "source": [
    "## Connect to Datastore\n",
    "\n",
    "First, we need to establish a connection to a MongoDB datastore via SuperDuperDB. You can set the `MONGODB_URI` environment variable to match your setup.\n",
    "Here are some examples of MongoDB URIs:\n",
    "\n",
    "* For testing (default connection): `mongomock://test`\n",
    "* Local MongoDB instance: `mongodb://localhost:27017`\n",
    "* MongoDB with authentication: `mongodb://superduper:superduper@mongodb:27017/documents`\n",
    "* MongoDB Atlas: `mongodb+srv://<username>:<password>@<atlas_cluster>/<database>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5ef986",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from superduperdb import superduper\n",
    "from superduperdb.backends.mongodb import Collection\n",
    "\n",
    "# Read the connection string from the environment; default to an in-memory mock for testing\n",
    "mongodb_uri = os.getenv(\"MONGODB_URI\", \"mongomock://test\")\n",
    "\n",
    "# Wrap the MongoDB database with SuperDuperDB and store artifacts on the local filesystem\n",
    "db = superduper(mongodb_uri, artifact_store='filesystem://.data')\n",
    "\n",
    "collection = Collection('multimodal')"
   ]
  },
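  {
   "cell_type": "markdown",
   "id": "7d2c9a41",
   "metadata": {},
   "source": [
    "If the credentials in an authenticated URI contain reserved characters such as `@` or `:`, they need to be percent-encoded before being embedded in the URI. The sketch below shows one way to do this with the standard library. The helper name `build_mongodb_uri` and the sample password are hypothetical and not part of SuperDuperDB.\n",
    "\n",
    "```python\n",
    "from urllib.parse import quote_plus\n",
    "\n",
    "\n",
    "def build_mongodb_uri(user, password, host, database):\n",
    "    # Hypothetical helper: percent-encode the credentials so that reserved\n",
    "    # characters cannot be mistaken for URI delimiters.\n",
    "    return f'mongodb://{quote_plus(user)}:{quote_plus(password)}@{host}/{database}'\n",
    "\n",
    "\n",
    "# Made-up credentials, used only to show the encoding\n",
    "uri = build_mongodb_uri('superduper', 'p@ss:word', 'mongodb:27017', 'documents')\n",
    "print(uri)\n",
    "```\n",
    "\n",
    "The resulting string could then be exported as `MONGODB_URI` before running the connection cell."
   ]
  },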
  {
   "cell_type": "markdown",
   "id": "6cd6d6b0",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "\n",
    "For simplicity and interactivity, we'll use a subset of the [Tiny-Imagenet dataset](https://paperswithcode.com/dataset/tiny-imagenet). The processes shown here can be applied to larger datasets with higher-resolution images. If working with larger datasets, especially with high-resolution images, it's recommended to use a machine with a GPU for efficiency.\n",
    "\n",
    "To insert images into the database, we'll use the `Encoder`-`Document` framework. This framework allows saving Python class instances as blobs in the `Datalayer` and retrieving them as Python objects. SuperDuperDB comes with built-in support for `PIL.Image` instances, making it easy to integrate Python AI models with the datalayer. If needed, you can also create custom encoders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f0f14fb-8e79-4bc6-88af-1a800aecb8db",
   "metadata": {},
   "outputs": [],
   "source": [
    "!curl -O https://superduperdb-public.s3.eu-west-1.amazonaws.com/coco_sample.zip\n",
    "!unzip coco_sample.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e41e6faa-6b83-46d8-ab37-6de6fd346ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from superduperdb import Document\n",
    "from superduperdb.ext.pillow import pil_image as i\n",
    "import glob\n",
    "\n",
    "# Use glob to get a list of image file paths in the 'images_tiny' directory\n",
    "images = glob.glob('images_tiny/*.jpg')\n",
    "\n",
    "# Create a list of SuperDuperDB Document instances with image data\n",
    "# Note: The 'uri' parameter is set to the file URI using the 'file://' scheme\n",
    "# Only the first 500 images are used for demonstration purposes\n",
    "documents = [Document({'image': i(uri=f'file://{img}')}) for img in images[:500]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b3a63bf-9e1f-4266-823a-7a2208937e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one Document from the list as a sanity check\n",
    "documents[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9c7e282",
   "metadata": {},
   "source": [
    "The wrapped Python dictionaries can be inserted directly into the `Datalayer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32a91a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert the list of Document instances into a collection using SuperDuperDB\n",
    "# Specify the 'i' encoder (pil_image) for the 'image' field\n",
    "db.execute(collection.insert_many(documents), encoders=(i,))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10d37264",
   "metadata": {},
   "source": [
    "You can verify that the images are correctly stored as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7282a0fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = db.execute(collection.find_one()).unpack()['image']\n",
    "\n",
    "# Resize the image to a width of 300 pixels, preserving the aspect ratio, and display it\n",
    "display(x.resize((300, int(300 * x.size[1] / x.size[0]))))"
   ]
  },
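  {
   "cell_type": "markdown",
   "id": "3e8b1f56",
   "metadata": {},
   "source": [
    "When displaying images, the aspect ratio is preserved by scaling the height by the same factor as the width. A minimal sketch of that arithmetic follows. The helper `scaled_size` is hypothetical and not part of SuperDuperDB or Pillow.\n",
    "\n",
    "```python\n",
    "def scaled_size(width, height, target_width=300):\n",
    "    # Scale the height by the same factor as the width to keep the aspect ratio.\n",
    "    return target_width, int(target_width * height / width)\n",
    "\n",
    "\n",
    "# Example with made-up dimensions for a 640x480 image\n",
    "print(scaled_size(640, 480))\n",
    "```\n",
    "\n",
    "With a `PIL.Image` instance `x`, this could be used as `x.resize(scaled_size(*x.size))`."
   ]
  },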
  {
   "cell_type": "markdown",
   "id": "dab27b50",
   "metadata": {},
   "source": [
    "## Build Models\n",
    "\n",
    "Now, let's prepare the CLIP model for multimodal search. This involves two components: a text encoder and a visual encoder. Once both are defined, you can search with either images or text to find matching items."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "916792d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import clip\n",
    "from superduperdb import vector\n",
    "from superduperdb.ext.torch import TorchModel\n",
    "\n",
    "# Load the CLIP model and obtain the preprocessing function\n",
    "model, preprocess = clip.load(\"RN50\", device='cpu')\n",
    "\n",
    "# Define a vector with shape (1024,)\n",
    "e = vector(shape=(1024,))\n",
    "\n",
    "# Create a TorchModel for text encoding\n",
    "text_model = TorchModel(\n",
    "    identifier='clip_text',  # Unique identifier for the model\n",
    "    object=model,  # CLIP model\n",
    "    preprocess=lambda x: clip.tokenize(x)[0],  # Model input preprocessing using CLIP\n",
    "    postprocess=lambda x: x.tolist(),  # Convert the model output to a list\n",
    "    encoder=e,  # Vector encoder with shape (1024,)\n",
    "    forward_method='encode_text',  # Use the 'encode_text' method for forward pass\n",
    ")\n",
    "\n",
    "# Create a TorchModel for visual encoding\n",
    "visual_model = TorchModel(\n",
    "    identifier='clip_image',  # Unique identifier for the model\n",
    "    object=model.visual,  # Visual part of the CLIP model\n",
    "    preprocess=preprocess,  # Visual preprocessing using CLIP\n",
    "    postprocess=lambda x: x.tolist(),  # Convert the output to a list\n",
    "    encoder=e,  # Vector encoder with shape (1024,)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b716bcb2",
   "metadata": {},
   "source": [
    "## Create a Vector-Search Index\n",
    "\n",
    "Now, let's create the index for vector-based searching. We register both models with the index at once. The `indexing_listener` specifies that `visual_model` is responsible for creating the vectors stored in the database. The `compatible_listener` specifies an alternative model, here `text_model`, that can be used to search those vectors. Because the two models accept different input types, this enables multimodal search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e0302c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from superduperdb import VectorIndex\n",
    "from superduperdb import Listener\n",
    "\n",
    "# Create a VectorIndex and add it to the database\n",
    "db.add(\n",
    "    VectorIndex(\n",
    "        'my-index',  # Unique identifier for the VectorIndex\n",
    "        indexing_listener=Listener(\n",
    "            model=visual_model,  # Visual model for embeddings\n",
    "            key='image',  # Key field in documents for embeddings\n",
    "            select=collection.find(),  # Select the documents for indexing\n",
    "            predict_kwargs={'batch_size': 10},  # Prediction arguments for the indexing model\n",
    "        ),\n",
    "        compatible_listener=Listener(\n",
    "            # Inactive listener: the text model is used to encode queries, not to process stored documents\n",
    "            model=text_model,\n",
    "            key='text',\n",
    "            active=False,\n",
    "            select=None,\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
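  {
   "cell_type": "markdown",
   "id": "9b60d2c7",
   "metadata": {},
   "source": [
    "Conceptually, both models map their inputs into the same vector space, so a query vector from the text model can be compared against image vectors written by the indexing model. The sketch below illustrates the idea with made-up three-dimensional vectors and cosine similarity. The similarity measure, the vectors and all names are assumptions for illustration, not SuperDuperDB's implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine(a, b):\n",
    "    # Cosine similarity between two equal-length vectors.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))\n",
    "\n",
    "\n",
    "# Toy stand-ins for vectors written by the indexing model (made-up values)\n",
    "indexed = {'dog.jpg': [0.9, 0.1, 0.0], 'car.jpg': [0.0, 0.2, 0.9], 'cat.jpg': [0.8, 0.3, 0.1]}\n",
    "\n",
    "# Toy stand-in for a query vector produced by the compatible model\n",
    "query = [1.0, 0.2, 0.0]\n",
    "\n",
    "# Rank the stored items by similarity to the query and keep the top 2\n",
    "top = sorted(indexed, key=lambda k: cosine(query, indexed[k]), reverse=True)[:2]\n",
    "print(top)\n",
    "```\n",
    "\n",
    "In the notebook itself, the vectors have 1024 dimensions and the ranking is handled by the `VectorIndex`."
   ]
  },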
  {
   "cell_type": "markdown",
   "id": "18971a6d",
   "metadata": {},
   "source": [
    "## Search Images Using Text\n",
    "\n",
    "Now we can demonstrate searching for images using text queries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab994b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "from superduperdb import Document\n",
    "\n",
    "query_string = 'sports'\n",
    "\n",
    "# Execute the 'like' query using the VectorIndex 'my-index' and find the top 3 results\n",
    "out = db.execute(\n",
    "    collection.like(Document({'text': query_string}), vector_index='my-index', n=3).find({})\n",
    ")\n",
    "\n",
    "# Display the images from the search results\n",
    "for r in out:\n",
    "    x = r['image'].x\n",
    "    display(x.resize((300, int(300 * x.size[1] / x.size[0]))))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94af81f3",
   "metadata": {},
   "source": [
    "Next, let's fetch a stored image to use as the query for an image-to-image search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e3ac22-044f-4675-976a-68ff9b59efe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = db.execute(collection.find_one({}))['image']\n",
    "img.x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f4f6b7d",
   "metadata": {},
   "source": [
    "## Search Images Using an Image\n",
    "\n",
    "Now let's try an image-to-image similarity search. We query the vector index `my-index` for the top 3 images most similar to the input image `img`, then display the retrieved images, resized for easier viewing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8569e4f-74f2-4ee5-9674-7829b2fcc62b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the 'like' query using the VectorIndex 'my-index' to find images similar to 'img'\n",
    "similar_images = db.execute(\n",
    "    collection.like(Document({'image': img}), vector_index='my-index', n=3).find({})\n",
    ")\n",
    "\n",
    "# Display the similar images from the search results\n",
    "# (the loop variable is named 'r' to avoid shadowing the 'i' encoder imported earlier)\n",
    "for r in similar_images:\n",
    "    x = r['image'].x\n",
    "    display(x.resize((300, int(300 * x.size[1] / x.size[0]))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
